{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.9459288 , -0.32437435,  0.73634136], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, done, _, info = self.env.step(action)\n",
    "\n",
    "        #一局游戏最多走N步\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdM0lEQVR4nO3dfXBTdb4G8OekadLXk9LSJvTSSneKYheKUqBE3XGudKnadX2pd3a9jHZYVkcNDIjDrN0VnHV2pizOrKu7is7srHjnjnYHZ6srC2pvwaJrKCVQLQUqu6LtpSTlrSdpoUmb/O4f2nMNVEzaJr8En8/MmTHnfNM+QfKQc05yogghBIiI4swgOwARfTexfIhICpYPEUnB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISApp5fPCCy9g1qxZSEtLQ2VlJfbt2ycrChFJIKV8/vKXv2DdunV46qmncODAAcyfPx/V1dXo7++XEYeIJFBkfLC0srISixYtwh//+EcAQCgUQlFREVavXo0nnnjiW+8fCoXQ19eH7OxsKIoS67hEFCEhBHw+HwoLC2EwXP61jTFOmXSBQAAulwv19fX6OoPBgKqqKjidznHv4/f74ff79dsnTpxAWVlZzLMS0cT09vZi5syZl52Je/mcPn0awWAQVqs1bL3VasXRo0fHvU9DQwN+/etfX7K+t7cXqqrGJCcRRc/r9aKoqAjZ2dnfOhv38pmI+vp6rFu3Tr899gBVVWX5ECWgSA6HxL18pk+fjpSUFHg8nrD1Ho8HNptt3PuYzWaYzeZ4xCOiOIn72S6TyYSKigq0tLTo60KhEFpaWmC32+Mdh4gkkbLbtW7dOtTV1WHhwoVYvHgxfv/732NoaAgrVqyQEYeIJJBSPj/5yU9w6tQpbNy4EW63G9dddx3eeeedSw5CE9GVS8r7fCbL6/XCYrFA0zQecCZKINE8N/nZLiKSguVDRFKwfIhICpYPEUnB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISAqWDxFJwfIhIilYPkQkBcuHiKRg+RCRFCwfIpKC5UNEUrB8iEgKlg8RScHyISIpWD5EJAXLh4ikYPkQkRQsHyKSguVDRFKwfIhICpYPEUnB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISIqoy2fPnj244447UFhYCEVR8Oabb4ZtF0Jg48aNmDFjBtLT01FVVYVjx46FzZw9exbLly+HqqrIycnBypUrMTg4OKkHQkTJJeryGRoawvz58/HCCy+Mu33z5s14/vnn8dJLL6GtrQ2ZmZmorq7G8PCwPrN8+XJ0dXWhubkZ27dvx549e/DQQw9N/FEQUfIRkwBANDU16bdDoZCw2WzimWee0dcNDAwIs9ksXn/9dSGEEIcPHxYARHt7uz6zc+dOoSiKOHHiRES/V9M0AUBomjaZ+EQ0xaJ5bk7pMZ/jx4/D7XajqqpKX2exWFBZWQmn0wkAcDqdyMnJwcKFC/WZqqoqGAwGtLW1jftz/X4/vF5v2EJEyW1Ky8ftdgMArFZr2Hqr1apvc7vdKCgoCNtuNBqRm5urz1ysoaEBFotFX4qKiqYyNhFJkBRnu+rr66Fpmr709vbKjkREkzSl5WOz2QAAHo8nbL3H49G32Ww29Pf3h20fHR3F2bNn9ZmLmc1mqKoathBRcpvS8ikpKYHNZkNLS4u+zuv1oq2tDXa7HQBgt9sxMDAAl8ulz+zatQuhUAiVlZVTGYeIEpgx2jsMDg7in//8p377+PHj6OjoQG5uLoqLi7F27Vr85je/wezZs1FSUoINGzagsLAQd911FwDg2muvxa233ooHH3wQL730EkZGRrBq1Sr89Kc/RWFh4ZQ9MCJKcNGeStu9e7cAcMlSV1cnhPjydPuGDRuE1WoVZrNZLF26VHR3d4f9jDNnzoj77rtPZGVlCVVVxYoVK4TP54s4A0+1EyWmaJ6bihBCSOy+CfF6vbBYLNA0jcd/iBJINM/NpDjbRURXHpYPEUnB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISAqWDxFJwfIhIilYPkQkBcuHiKRg+RCRFCwfIpKC5UNEUrB8iEgKlg8RScHyISIpWD5EJAXLh4ikYPkQkRQsHyKSguVDRFKwfIhICpYPEUnB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISAqWDxFJwfIhIimiKp+GhgYsWrQI2dnZKCgowF133YXu7u6wmeHhYTgcDuTl5SErKwu1tbXweDxhMz09PaipqUFGRgYKCgqwfv16jI6OTv7REFHSiKp8Wltb4XA4sHfvXjQ3N2NkZATLli3D0NCQPvPYY4/h7bffxrZt29Da2oq+vj7cc889+vZgMIiamhoEAgF89NFHePXVV7F161Zs3Lhx6h4VESU+MQn9/f0CgGhtbRVCCDEwMCBSU1PFtm3b9JkjR44IAMLpdAohhNixY4cwGAzC7XbrM1u2bBGqqgq/3x/R79U0TQAQmqZNJj4RTbFonpuTOuajaRoAIDc3FwDgcrkwMjKCqqoqfWbOnDkoLi6G0+kEADidTsybNw9Wq1Wfqa6uhtfrRVdX17i/x+/3w+v1hi1ElNwmXD6hUAhr167FjTfeiLlz5wIA3G43TCYTcnJywmatVivcbrc+8/XiGds+tm08DQ0NsFgs+lJUVDTR2ESUICZcPg6HA4cOHUJjY+NU5hlXfX09NE3Tl97e3pj/TiKKLeNE7rRq1Sps374de/bswcyZM/X1NpsNgUAAAwMDYa9+PB4PbDabPrNv376wnzd2Nmxs5mJmsxlms3kiUYkoQUX1ykcIgVWrVqGpqQm7du1CSUlJ2PaKigqkpqaipaVFX9fd3Y2enh7Y7XYAgN1uR2dnJ/r7+/WZ5uZmqKqKsrKyyTwWIkoiUb3ycTgceO211/DWW28hOztbP0ZjsViQnp4Oi8WClStXYt26dcjNzYWqqli9ejXsdjuWLFkCAFi2bBnKyspw//33Y/PmzXC73XjyySfhcDj46obouySa02gAxl1eeeUVfebChQvi0UcfFdOmTRMZGRni7rvvFidPngz7OZ9//rm47bbbRHp6upg+fbp4/PHHxcjISMQ5eKqdKDFF89xUhBBCXvVNjNfrhcVigaZpUFVVdhwi+ko0z01+touIpGD5EJEULB8ikoLlQ0RSsHyISAqWDxFJwfIhIilYPkQkBcuHiKSY0KfaiaaCCIUQHBpCKBCAkpKClPR0KCYTFEWRHY3igOVDcSeEwMjZszi1cye0ffsQOH0aBrMZGaWlKLj9dmSXl0NJSZEdk2KM5UNxJYSAv68Pnz/3HIa6u4GvPloYHByEduYMBjs7UfjAA8ivrmYBXeF4zIfiKnj+PHpefhlDR49ChEI45/dj/+nTOOb1IiQEgufP48R//Re0AweQhJ95pijwlQ/FldbeDt8nn0AIgZ6hIWw4eBDdmoZMoxE/v/pq/KSkBDh/Hp6mJmTPnYuU9HTZkSlG+MqH4mrw8GEgFIIA8NvOThweGEBQCHhHRvDHI0dw6Nw5AMD5zz5D8MIFuWEpplg+JI13ZCTsdiAUgj8Y/PJGMIiR06clpKJ4YfmQFAqAf7fZYPzaafWrVRVXZWUBAEJ+Pwb27oUIhSQlpFjjMR+Kq5TMTACAoiioKy1Fdmoq/ufkScxIT8eDV1+NgrQ0fVZrb4ft3nuRkpEhKy7FEMuH4ir3ppvQ/9ZbEKOjMBoM+I9Zs3DvrFkYe/3z9TcY+j0enP/sM2R9//t84+EViLtdFFfmGTOQUVqq31YUBQZFgfLV8nWh4WFoLpf+XiC6srB8KK5SMjKQPW9exPNaWxuC58/HMBHJwvKhuMtZsgSKMbI9/uG+Pvg+/phvOLwCsXwo7tKvuips1+uyQiGcczpjG4ikYPlQ3CmpqbAsWhTx/NCnn2J0YCB2gUgKlg/FnaIoyC4vhyHCr8cOuN0Y+vRT7npdYVg+JEVmaSmyysoinj+ze3cM05AMLB+Sw2BAzg03RDx+/tgxftziCsPyISkURYFaXg7jtGkRzQdOncLQsWMxTkXxxPIhaVLz85E5e3bE82f37IEY++ApJT2WD0mjpKTAsnAhYIjsr+Hg4cMIcNfrisHyIWkURYGlogLGrz7J/m1GvV4MdnXxrNcVguVDUhlzcpAV6cctQiEM7NsHMToa21AUFywfkkoxGqFed13Eu16+zk4ETp2KbSiKC5YPSaUoCiwLFujX+fk2QZ8P5z78kLteVwCWD0mXOm3al69+IuR1uRDy+2MXiOKC5UPSKUZjVGe9LvT0IODxxDgVxVpU5bNlyxaUl5dDVVWoqgq73Y6dO3fq24eHh+FwOJCXl4esrCzU1tbCc9Ffkp6eHtTU1CAjIwMFBQVYv349RnkA8Tsvq6wMxuzsiGaDQ0PQXC7ueiW5qMpn5syZ2LRpE1wuF/bv349bbrkFd955J7q6ugAAjz32GN5++21s27YNra2t6Ovrwz333KPfPxgMoqamBoFAAB999BFeffVVbN26FRs3bpzaR0VJx5Sfj6zvfz/i+XP/+AfPeiU5RUzyn4/c3Fw888wzuPfee5Gfn4/XXnsN9957LwDg6NGjuPbaa+F0OrFkyRLs3LkTP/rRj9DX1wer1QoAeOmll/CLX/wCp06dgslkGvd3+P1++L+2j+/1elFUVARN06Cq6mTiUwI5949/4LPf/jai2ZTMTMx++mlklJby+s4JxOv1wmKxRPTcnPAxn2AwiMbGRgwNDcFut8PlcmFkZARVVVX6zJw5c1BcXAznVxeDcjqdmDdvnl48AFBdXQ2v16u/ehpPQ0MDLBaLvhQVFU00NiWwjNmzkZqXF9FscGgIvs7OGCeiWIq6fDo7O5GVlQWz2YyHH34YTU1NKCsrg9vthslkQk5OTti81WqF2+0GALjd7rDiGds+tu2b1NfXQ9M0fent7Y02NiUBU35+5Fc4BHDugw8gLvriQUoeUX91zjXXXIOOjg5omoY33ngDdXV1aG1tjUU2ndlshjnCC09Rcsuvroa2fz8QwQdIh0+cwPnPPkPmNddw1ysJRf3Kx2QyobS0FBUVFWhoaMD8+fPx3HPPwWazIRAIYOCiy116PB7YbDYAgM1mu+Ts19jtsRn67lIUBZlz5iBtxoyI5kPDw/AeOBDjVBQrk36fTygUgt/vR0VFBVJTU9HS0qJv6+7uRk9PD+x2OwDAbrejs7MT/f39+kxzczNUVUVZFFe1oytXSmYm1AULIp7X2tsRGh6OYSKKlah2u+rr63HbbbehuLgYPp8Pr732Gt5//328++67sFgsWLlyJdatW4fc3FyoqorVq1fDbrdjyZIlAIBly5ahrKwM999/PzZv3gy3240nn3wSDoeDu1WkUxcswKl33oEIBL519sIXX+D8v/6F7Llz45CMplJU5dPf348HHngAJ0+ehMViQXl5Od5991388Ic/BAA8++yzMBgMqK2thd/vR3V1NV588UX9/ikpKdi+fTseeeQR2O12ZGZmoq6uDk8//fTUPipKWoqiIGvOHJgLCjD8v//7rfNidBSnm5v5lcpJaNLv85EhmvcSUPIRoRD6/vu/4X7jjYjmTTYbrtm0Cabc3Bgno28Tl/f5EMWMoiD7uutgSEuLaDzQ348Lx4/HOBRNNZYPJRxFUZBZWgrT9OmR3SEU+vLjFqFQbIPRlGL5UEIypKfDUlkZ8bz34EGMaloME9FUY/lQQlIUBTmLF0f8raajmobBI0f4SfckwvKhhJVWVARzYWFEs2J0FN6DBwHueiUNlg8lrJTMTKjz50c8r7W3Y9Tni2EimkosH0pYiqJg2g9+AMUY2dvRRjUNvo8/5q5XkmD5UEIzFxYi/Xvfi2hWBIPQuOuVNFg+lNBSMjKQHcUVDn0ff4wRnvVKCiwfSmiKomDajTdGfHH5kXPn4Ovo4K5XEmD5UMJLKy5GZqQXGQuFcPq997jrlQRYPpTwUtLSkH399RHPX+jtxQVe7TLhsXwoKeT+4AdIifBDxEGfD4OHD3PXK8GxfCgppBUWIuOqqyKeP/fBB9z1SnAsH0oOKSmYdtNNEY+f/9e/cKGnJ4aBaLJYPpQUFEVBdnk5jBd9O8o3CQ0Pw9vREdNMNDksH0oaJqsVGRG+4RAAho4ejWEamiyWDyUNJSUFOZWVAC+XekVg+VDSUBQFakUFjLx07hWB5UNJxZidHfFxH5ZUYmP5UFIxpKUhv7oaSEm57FxKZibyb701TqloIlg+lFQURUHeLbcg96abvvHzXgazGTP+8z+RXlIS53QUjai/q51ItpSMDMz8+c9hVFWcef99BIeGgFAIitGI1OnTYautRd4tt0CJ8MOoJAfLh5JSqsWCf1uxAnlVVRj69FMEh4ZgystDZlkZTHl5LJ4kwPKhpGUwGpFRUoIM7l4lJf7zQERSsHyISAqWDxFJwfIhIilYPkQkBcuHiKRg+RCRFCwfIpKC5UNEUrB8iEiKSZXPpk2boCgK1q5dq68bHh6Gw+FAXl4esrKyUFtbC4/HE3a/np4e1NTUICMjAwUFBVi/fj1GR0cnE4WIksyEy6e9vR0vv/wyysvLw9Y/9thjePvtt7Ft2za0trair68P99xzj749GAyipqYGgUAAH330EV599VVs3boVGzdunPijIKLkIybA5/OJ2bNni+bmZnHzzTeLNWvWCCGEGBgYEKmpqWLbtm367JEjRwQA4XQ6hRBC7NixQxgMBuF2u/WZLVu2CFVVhd/vH/f3DQ8PC03T9KW3t1cAEJqmTSQ+EcWIpmkRPzcn9MrH4XCgpqYGVVVVYetdLhdGRkbC1s+ZMwfFxcVwOp0AAKfTiXnz5sFqteoz1dXV8Hq96OrqGvf3NTQ0wGKx6EtRUdFEYhNRAom6fBobG3HgwAE0NDRcss3tdsNkMiHnomvsWq1WuN1ufebrxTO2fWzbeOrr66Fpmr708nu4iZJeVNfz6e3txZo1a9Dc3Iy0tLRYZbqE2WyG2WyO2+8jotiL6pWPy+VCf38/FixYAKPRCKPRiNbWVjz//PMwGo2wWq0IBAIYGBgIu5/H44HNZgMA2Gy2S85+jd0emyGiK19U5bN06VJ0dnaio6NDXxYuXIjly5fr/52amoqWlhb9Pt3d3ejp6YHdbgcA2O12dHZ2or+/X59pbm6GqqooKyuboodFRIkuqt2u7OxszJ07N2xdZmYm8vLy9PUrV67EunXrkJubC1VVsXr1atjtdixZsgQAsGzZMpSVleH+++/H5s2b4Xa78eSTT8LhcHDXiug7ZMqv4fzss8/CYDCgtrYWfr8f1dXVePHFF/XtKSkp2L59Ox555BHY7XZkZmairq4OTz/99FRHIaIEpgghhOwQ0fJ6vbBYLNA0DSq/lZIoYUTz3ORnu4hICpYPEUnB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISAqWDxFJwfIhIilYPkQkBcuHiKRg+RCRFCwfIpKC5UNEUrB8iEgKlg8RScHyISIpWD5EJAXLh4ikYPkQkRQsHyKSguVDRFKwfIhICpYPEUnB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISAqWDxFJwfIhIilYPkQkBcuHiKQwyg4wEUIIAIDX65WchIi+buw5OfYcvZykLJ8zZ84AAIqKiiQnIaLx+Hw+WCyWy84kZfnk5uYCAHp6er71ASYar9eLoqIi9Pb2QlVV2XEixtzxlay5hRDw+XwoLCz81tmkLB+D4ctDVRaLJan+x3ydqqpJmZ254ysZc0f6goAHnIlICpYPEUmRlOVjNpvx1FNPwWw2y44StWTNztzxlay5o6GISM6JERFNsaR85UNEyY/lQ0RSsHyISAqWDxFJwfIhIimSsnxeeOEFzJo1C2lpaaisrMS+ffuk5tmzZw/uuOMOFBYWQlEUvPnmm2HbhRDYuHEjZsyYgfT0dFRVVeHYsWNhM2fPnsXy5cuhqipycnKwcuVKDA4OxjR3Q0MDFi1ahOzsbBQUFOCuu+5Cd3d32Mzw8DAcDgfy8vKQlZWF2tpaeDyesJmenh7U1NQgIyMDBQUFWL9+PUZHR2OWe8uWLSgvL9ff/Wu327Fz586EzjyeTZs2QVEUrF27NumyTwmRZBobG4XJZBJ//vOfRVdXl3jwwQdFTk6O8Hg80jLt2LFD/OpXvxJ//etfBQDR1NQUtn3Tpk3CYrGIN998U3z88cfixz/+sSgpKREXLlzQZ2699VYxf/58sXfvXvHBBx+I0tJScd9998U0d3V1tXjllVfEoUOHREdHh7j99ttFcXGxGBwc1GcefvhhUVRUJFpaWsT+/fvFkiVLxA033KBvHx0dFXPnzhVVVVXi4MGDYseOHWL69Omivr4+Zrn/9re/ib///e/i008/Fd3d3eKXv/ylSE1NFYcOHUrYzBfbt2+fmDVrligvLxdr1qzR1ydD9qmSdOWzePFi4XA49NvBYFAUFhaKhoYGian+38XlEwqFhM1mE88884y+bmBgQJjNZvH6668LIYQ4fPiwACDa29v1mZ07dwpFUcSJEyfilr2/v18AEK2trXrO1NRUsW3bNn3myJEjAoBwOp1CiC+L12AwCLfbrc9s2bJFqKoq/H5/3LJPmzZN/OlPf0qKzD6fT8yePVs0NzeLm2++WS+fZMg+lZJqtysQCMDlcqGqqkpfZzAYUFVVBafTKTHZNzt+/DjcbndYZovFgsrKSj2z0+lETk4OFi5cqM9UVVXBYDCgra0tblk1TQPw/1cNcLlcGBkZCcs+Z84cFBcXh2WfN28erFarPlNdXQ2v14uurq6YZw4Gg2hsbMTQ0BDsdntSZHY4HKipqQnLCCTHn/dUSqpPtZ8+fRrBYDDsDx4ArFYrjh49KinV5bndbgAYN/PYNrfbjYKCgrDtRqMRubm5+kyshUIhrF27FjfeeCPmzp2r5zKZTMjJybls9vEe29i2WOns7ITdbsfw8DCysrLQ1NSEsrIydHR0JGxmAGhsbMSBAwfQ3t5+ybZE/vOOhaQqH4odh8OBQ4cO4cMPP5QdJSLXXHMNOjo6oGka3njjDdTV1aG1tVV2rMvq7e3FmjVr0NzcjLS0NNlxpEuq3a7p06cjJSXlkqP/Ho8HNptNUqrLG8t1ucw2mw39/f1h20dHR3H27Nm4PK5Vq1Zh+/bt2L17N2bOnKmvt9lsCAQCGBgYuGz28R7b2LZYMZlMKC0tRUVFBRoaGjB//nw899xzCZ3Z5XKhv78fCxYsgNFohNFoRGtrK55//nkYjUZYrdaEzR4LSVU+JpMJFRUVaGlp0deFQiG0tLTAbrdLTPbNSkpKYLPZwjJ7vV60tbXpme12OwYGBuByufSZXbt2IRQKobKyMmbZhBBYtWoVmpqasGvXLpSUlIRtr6ioQGpqalj27u5u9PT0hGXv7OwMK8/m5maoqoqysrKYZb9YKBSC3+9P6MxLly5FZ2cnOjo69GXhwoVYvny5/t+Jmj0mZB/xjlZjY6Mwm81i69at4vDhw+Khhx4SOTk5YUf/483n84mDBw+KgwcPCgDid7/7nTh48KD44osvhBBfnmrPyckRb731lvjkk0/EnXfeOe6p9uuvv160tbWJDz/8UMyePTvmp9ofeeQRYbFYxPvvvy9OnjypL+fPn9dnHn74YVFcXCx27dol9u/fL+x2u7Db7fr2sVO/y5YtEx0dHeKdd94R+fn5MT31+8QTT4jW1lZx/Phx8cknn4gnnnhCKIoi3nvvvYTN/E2+frYr2bJPVtKVjxBC/OEPfxDFxcXCZDKJxYsXi71790rNs3v3bgHgkqWurk4I8eXp9g0bNgir1SrMZrNYunSp6O7uDvsZZ86cEffdd5/IysoSqqqKFStWCJ/PF9Pc42UGIF555RV95sKFC+LRRx8V06ZNExkZGeLuu+8WJ0+eDPs5n3/+ubjttttEenq6mD59unj88cfFyMhIzHL/7Gc/E1dddZUwmUwiPz9fLF26VC+eRM38TS4un2TKPlm8ng8RSZFUx3yI6MrB8iEiKVg+RCQFy4eIpGD5EJEULB8ikoLlQ0RSsHyISAqWDxFJwfIhIilYPkQkxf8B1iz0LwVxArQAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.figure(figsize=(3, 3))\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env.observation_space= Box([-1. -1. -8.], [1. 1. 8.], (3,), float32)\n",
      "env.action_space= Box(-2.0, 2.0, (1,), float32)\n",
      "state= [-0.966843    0.25537148 -0.4515908 ]\n",
      "action= [0.48694044]\n",
      "next_state= [-0.9644128   0.2644012  -0.18702114]\n",
      "reward= -8.33439976337045\n",
      "done= False\n"
     ]
    }
   ],
   "source": [
    "#认识游戏环境\n",
    "def test_env():\n",
    "    print('env.observation_space=', env.observation_space)\n",
    "    print('env.action_space=', env.action_space)\n",
    "\n",
    "    state = env.reset()\n",
    "    action = env.action_space.sample()\n",
    "    next_state, reward, done, _ = env.step(action)\n",
    "\n",
    "    print('state=', state)\n",
    "    print('action=', action)\n",
    "    print('next_state=', next_state)\n",
    "    print('reward=', reward)\n",
    "    print('done=', done)\n",
    "\n",
    "\n",
    "test_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=11, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#计算动作的模型,也是真正要用的模型\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 11),\n",
    ")\n",
    "\n",
    "#经验网络,用于评估一个状态的分数\n",
    "next_model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 11),\n",
    ")\n",
    "\n",
    "#把model的参数复制给next_model\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8, 1.2000000000000002)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    #走神经网络,得到一个动作\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    action = model(state).argmax().item()\n",
    "\n",
    "    if random.random() < 0.01:\n",
    "        action = random.choice(range(11))\n",
    "\n",
    "    #离散动作连续化\n",
    "    action_continuous = action\n",
    "    action_continuous /= 10\n",
    "    action_continuous *= 4\n",
    "    action_continuous -= 2\n",
    "\n",
    "    return action, action_continuous\n",
    "\n",
    "\n",
    "get_action([0.29292667, 0.9561349, 1.0957013])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#样本池\n",
    "datas = []\n",
    "\n",
    "\n",
    "#向样本池中添加N条数据,删除M条最古老的数据\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #玩到新增了N个数据为止\n",
    "    while len(datas) - old_count < 200:\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action, action_continuous = get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step([action_continuous])\n",
    "\n",
    "            #记录数据样本\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "    #数据上限,超出时从最古老的开始删除\n",
    "    while len(datas) > 5000:\n",
    "        datas.pop(0)\n",
    "\n",
    "\n",
    "update_data()\n",
    "\n",
    "len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_41896/1416897299.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:230.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[ 5.8002e-01, -8.1460e-01, -4.8659e+00],\n",
       "         [ 4.2210e-01, -9.0655e-01, -9.3213e-01],\n",
       "         [-7.9495e-01, -6.0667e-01, -4.9727e+00],\n",
       "         [ 9.9486e-01,  1.0124e-01, -5.8302e+00],\n",
       "         [ 2.9155e-01,  9.5656e-01, -6.6488e+00],\n",
       "         [-3.6837e-01,  9.2968e-01, -7.2113e+00],\n",
       "         [ 9.6127e-01, -2.7561e-01, -3.4661e+00],\n",
       "         [ 3.4939e-01, -9.3698e-01, -5.2368e+00],\n",
       "         [-9.3281e-01, -3.6038e-01, -8.0000e+00],\n",
       "         [ 9.7827e-01, -2.0733e-01, -6.0045e+00],\n",
       "         [ 8.6188e-01, -5.0712e-01, -6.4600e+00],\n",
       "         [-3.8155e-01,  9.2435e-01, -7.4915e+00],\n",
       "         [ 9.7981e-01, -1.9994e-01, -6.0542e+00],\n",
       "         [-9.9146e-01, -1.3039e-01, -5.2345e+00],\n",
       "         [-9.6150e-01, -2.7482e-01, -7.2963e+00],\n",
       "         [-4.0645e-01, -9.1367e-01,  4.9823e+00],\n",
       "         [-9.3281e-01, -3.6038e-01, -8.0000e+00],\n",
       "         [-2.9172e-01,  9.5650e-01, -5.2620e+00],\n",
       "         [ 9.7827e-01, -2.0732e-01, -6.0045e+00],\n",
       "         [ 7.8532e-01,  6.1909e-01, -5.9155e+00],\n",
       "         [ 9.9583e-01,  9.1269e-02, -5.7730e+00],\n",
       "         [ 6.5353e-01, -7.5690e-01, -6.6411e+00],\n",
       "         [-3.6555e-02,  9.9933e-01, -7.0982e+00],\n",
       "         [ 2.7450e-03, -1.0000e+00, -6.8941e+00],\n",
       "         [ 3.5581e-01,  9.3456e-01, -3.9560e+00],\n",
       "         [ 3.5583e-01, -9.3455e-01, -6.9688e+00],\n",
       "         [-7.0814e-03, -9.9997e-01, -7.3909e+00],\n",
       "         [-9.9943e-01,  3.3865e-02, -8.0000e+00],\n",
       "         [-9.1384e-01,  4.0607e-01, -7.7365e+00],\n",
       "         [-9.9995e-01, -1.0139e-02, -8.0000e+00],\n",
       "         [-4.2508e-01, -9.0516e-01, -4.1861e+00],\n",
       "         [ 7.5918e-01, -6.5088e-01, -4.6177e+00],\n",
       "         [-3.6554e-02,  9.9933e-01, -7.0982e+00],\n",
       "         [-9.9951e-01,  3.1323e-02, -8.0000e+00],\n",
       "         [ 7.7518e-01,  6.3174e-01, -5.9911e+00],\n",
       "         [ 1.5038e-01, -9.8863e-01,  3.1717e+00],\n",
       "         [ 9.2865e-01,  3.7096e-01, -5.7512e+00],\n",
       "         [-3.9137e-01, -9.2023e-01, -7.9009e+00],\n",
       "         [-9.3372e-01, -3.5800e-01, -8.0000e+00],\n",
       "         [ 9.9883e-01,  4.8401e-02, -2.8231e+00],\n",
       "         [ 8.8798e-01, -4.5988e-01, -3.9728e+00],\n",
       "         [ 8.6188e-01, -5.0712e-01, -6.4600e+00],\n",
       "         [ 6.5103e-01, -7.5905e-01, -6.6003e+00],\n",
       "         [ 8.6188e-01, -5.0712e-01, -6.4600e+00],\n",
       "         [-9.9951e-01,  3.1323e-02, -8.0000e+00],\n",
       "         [ 5.7076e-01,  8.2111e-01, -6.2313e+00],\n",
       "         [ 9.2865e-01,  3.7095e-01, -5.7511e+00],\n",
       "         [-4.1190e-01,  9.1123e-01, -7.5697e+00],\n",
       "         [-6.8940e-01, -7.2438e-01, -7.8640e+00],\n",
       "         [-2.2773e-01, -9.7372e-01, -3.6958e+00],\n",
       "         [ 3.5443e-01, -9.3508e-01, -6.9275e+00],\n",
       "         [-7.1390e-03, -9.9997e-01, -7.3888e+00],\n",
       "         [ 9.2865e-01,  3.7095e-01, -5.7511e+00],\n",
       "         [-6.9329e-01,  7.2066e-01, -7.7320e+00],\n",
       "         [ 7.8532e-01,  6.1909e-01, -5.9155e+00],\n",
       "         [ 6.5083e-01, -7.5923e-01, -6.5981e+00],\n",
       "         [-7.0814e-03, -9.9997e-01, -7.3909e+00],\n",
       "         [-9.3281e-01, -3.6038e-01, -8.0000e+00],\n",
       "         [-8.9620e-01,  4.4365e-01, -7.4990e+00],\n",
       "         [ 3.5458e-01, -9.3503e-01, -6.9296e+00],\n",
       "         [ 8.6168e-01, -5.0745e-01, -6.4575e+00],\n",
       "         [-9.3279e-01, -3.6042e-01, -8.0000e+00],\n",
       "         [-3.8155e-01,  9.2435e-01, -7.4915e+00],\n",
       "         [ 2.9157e-01,  9.5655e-01, -6.6487e+00]]),\n",
       " tensor([[9],\n",
       "         [6],\n",
       "         [9],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [9],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [9],\n",
       "         [7],\n",
       "         [9],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [9],\n",
       "         [0],\n",
       "         [9],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [9],\n",
       "         [0],\n",
       "         [7],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [9],\n",
       "         [9],\n",
       "         [9],\n",
       "         [0],\n",
       "         [0]]),\n",
       " tensor([[ -3.2766],\n",
       "         [ -1.3753],\n",
       "         [ -8.6740],\n",
       "         [ -3.4134],\n",
       "         [ -6.0502],\n",
       "         [ -8.9991],\n",
       "         [ -1.2833],\n",
       "         [ -4.2185],\n",
       "         [-14.0916],\n",
       "         [ -3.6530],\n",
       "         [ -4.4585],\n",
       "         [ -9.4667],\n",
       "         [ -3.7099],\n",
       "         [-11.8091],\n",
       "         [-13.5240],\n",
       "         [ -6.4406],\n",
       "         [-14.0916],\n",
       "         [ -6.2579],\n",
       "         [ -3.6531],\n",
       "         [ -3.9489],\n",
       "         [ -3.3451],\n",
       "         [ -5.1501],\n",
       "         [ -7.6260],\n",
       "         [ -7.2141],\n",
       "         [ -3.0259],\n",
       "         [ -6.3158],\n",
       "         [ -7.9548],\n",
       "         [-16.0605],\n",
       "         [-13.4065],\n",
       "         [-16.2086],\n",
       "         [ -5.7943],\n",
       "         [ -2.6372],\n",
       "         [ -7.6260],\n",
       "         [-16.0763],\n",
       "         [ -4.0610],\n",
       "         [ -3.0226],\n",
       "         [ -3.4561],\n",
       "         [-10.1373],\n",
       "         [-14.1058],\n",
       "         [ -0.8033],\n",
       "         [ -1.8106],\n",
       "         [ -4.4585],\n",
       "         [ -5.1018],\n",
       "         [ -4.4585],\n",
       "         [-16.0763],\n",
       "         [ -4.8150],\n",
       "         [ -3.4560],\n",
       "         [ -9.7154],\n",
       "         [-11.6225],\n",
       "         [ -4.6104],\n",
       "         [ -6.2620],\n",
       "         [ -7.9519],\n",
       "         [ -3.4560],\n",
       "         [-11.4431],\n",
       "         [ -3.9489],\n",
       "         [ -5.0992],\n",
       "         [ -7.9548],\n",
       "         [-14.0917],\n",
       "         [-12.8201],\n",
       "         [ -6.2646],\n",
       "         [ -4.4557],\n",
       "         [-14.0914],\n",
       "         [ -9.4667],\n",
       "         [ -6.0499]]),\n",
       " tensor([[ 3.4939e-01, -9.3698e-01, -5.2368e+00],\n",
       "         [ 3.5055e-01, -9.3654e-01, -1.5520e+00],\n",
       "         [-9.2396e-01, -3.8249e-01, -5.1877e+00],\n",
       "         [ 9.7981e-01, -1.9994e-01, -6.0542e+00],\n",
       "         [ 5.7075e-01,  8.2113e-01, -6.2314e+00],\n",
       "         [-3.6545e-02,  9.9933e-01, -6.8140e+00],\n",
       "         [ 8.8798e-01, -4.5988e-01, -3.9728e+00],\n",
       "         [ 7.1879e-02, -9.9741e-01, -5.6995e+00],\n",
       "         [-9.9951e-01,  3.1323e-02, -8.0000e+00],\n",
       "         [ 8.6188e-01, -5.0712e-01, -6.4600e+00],\n",
       "         [ 6.5103e-01, -7.5905e-01, -6.6003e+00],\n",
       "         [-3.6555e-02,  9.9933e-01, -7.0982e+00],\n",
       "         [ 8.6457e-01, -5.0251e-01, -6.5042e+00],\n",
       "         [-9.8864e-01,  1.5029e-01, -5.6323e+00],\n",
       "         [-9.9641e-01,  8.4620e-02, -7.2624e+00],\n",
       "         [-1.9643e-01, -9.8052e-01,  4.4171e+00],\n",
       "         [-9.9951e-01,  3.1323e-02, -8.0000e+00],\n",
       "         [-5.3774e-02,  9.9855e-01, -4.8446e+00],\n",
       "         [ 8.6188e-01, -5.0711e-01, -6.4600e+00],\n",
       "         [ 9.2865e-01,  3.7095e-01, -5.7511e+00],\n",
       "         [ 9.7827e-01, -2.0732e-01, -6.0045e+00],\n",
       "         [ 3.5583e-01, -9.3455e-01, -6.9688e+00],\n",
       "         [ 2.9157e-01,  9.5655e-01, -6.6487e+00],\n",
       "         [-3.5924e-01, -9.3324e-01, -7.4041e+00],\n",
       "         [ 5.1545e-01,  8.5692e-01, -3.5551e+00],\n",
       "         [-7.6848e-03, -9.9997e-01, -7.4297e+00],\n",
       "         [-3.9137e-01, -9.2023e-01, -7.9009e+00],\n",
       "         [-9.1284e-01,  4.0831e-01, -7.7346e+00],\n",
       "         [-6.9329e-01,  7.2066e-01, -7.7320e+00],\n",
       "         [-9.2932e-01,  3.6929e-01, -7.7676e+00],\n",
       "         [-6.2122e-01, -7.8364e-01, -4.6249e+00],\n",
       "         [ 5.8002e-01, -8.1460e-01, -4.8659e+00],\n",
       "         [ 2.9157e-01,  9.5655e-01, -6.6487e+00],\n",
       "         [-9.1384e-01,  4.0607e-01, -7.7365e+00],\n",
       "         [ 9.2379e-01,  3.8289e-01, -5.8173e+00],\n",
       "         [ 2.7488e-01, -9.6148e-01,  2.5502e+00],\n",
       "         [ 9.9583e-01,  9.1269e-02, -5.7730e+00],\n",
       "         [-7.1883e-01, -6.9518e-01, -8.0000e+00],\n",
       "         [-9.9943e-01,  3.3865e-02, -8.0000e+00],\n",
       "         [ 9.9440e-01, -1.0572e-01, -3.0868e+00],\n",
       "         [ 7.5918e-01, -6.5088e-01, -4.6177e+00],\n",
       "         [ 6.5103e-01, -7.5905e-01, -6.6003e+00],\n",
       "         [ 3.5458e-01, -9.3503e-01, -6.9296e+00],\n",
       "         [ 6.5103e-01, -7.5905e-01, -6.6003e+00],\n",
       "         [-9.1384e-01,  4.0607e-01, -7.7365e+00],\n",
       "         [ 7.8532e-01,  6.1909e-01, -5.9155e+00],\n",
       "         [ 9.9583e-01,  9.1257e-02, -5.7729e+00],\n",
       "         [-6.5174e-02,  9.9787e-01, -7.1863e+00],\n",
       "         [-9.1707e-01, -3.9874e-01, -8.0000e+00],\n",
       "         [-4.2508e-01, -9.0516e-01, -4.1861e+00],\n",
       "         [-7.1390e-03, -9.9997e-01, -7.3888e+00],\n",
       "         [-3.9133e-01, -9.2025e-01, -7.8988e+00],\n",
       "         [ 9.9583e-01,  9.1257e-02, -5.7729e+00],\n",
       "         [-3.8155e-01,  9.2435e-01, -7.4915e+00],\n",
       "         [ 9.2865e-01,  3.7095e-01, -5.7511e+00],\n",
       "         [ 3.5443e-01, -9.3508e-01, -6.9275e+00],\n",
       "         [-3.9137e-01, -9.2023e-01, -7.9009e+00],\n",
       "         [-9.9951e-01,  3.1324e-02, -8.0000e+00],\n",
       "         [-6.7267e-01,  7.3994e-01, -7.4662e+00],\n",
       "         [-7.0814e-03, -9.9997e-01, -7.3909e+00],\n",
       "         [ 6.5083e-01, -7.5923e-01, -6.5981e+00],\n",
       "         [-9.9951e-01,  3.1276e-02, -8.0000e+00],\n",
       "         [-3.6554e-02,  9.9933e-01, -7.0982e+00],\n",
       "         [ 5.7077e-01,  8.2111e-01, -6.2313e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#获取一批数据样本\n",
    "def get_sample():\n",
    "    #从样本池中采样\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    #[b, 3]\n",
    "    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.5589],\n",
       "        [0.1906],\n",
       "        [0.7206],\n",
       "        [0.5936],\n",
       "        [0.8485],\n",
       "        [0.9757],\n",
       "        [0.3763],\n",
       "        [0.6412],\n",
       "        [1.1037],\n",
       "        [0.5958],\n",
       "        [0.6672],\n",
       "        [1.0034],\n",
       "        [0.6005],\n",
       "        [0.7048],\n",
       "        [1.0028],\n",
       "        [0.9286],\n",
       "        [1.1037],\n",
       "        [0.7828],\n",
       "        [0.5958],\n",
       "        [0.6724],\n",
       "        [0.5879],\n",
       "        [0.7576],\n",
       "        [0.9325],\n",
       "        [0.9010],\n",
       "        [0.5726],\n",
       "        [0.8559],\n",
       "        [0.9686],\n",
       "        [1.0380],\n",
       "        [1.0372],\n",
       "        [1.0472],\n",
       "        [0.5861],\n",
       "        [0.4862],\n",
       "        [0.9325],\n",
       "        [1.0386],\n",
       "        [0.6829],\n",
       "        [0.6667],\n",
       "        [0.6113],\n",
       "        [1.0915],\n",
       "        [1.1034],\n",
       "        [0.3418],\n",
       "        [0.4100],\n",
       "        [0.6672],\n",
       "        [0.7535],\n",
       "        [0.6672],\n",
       "        [1.0386],\n",
       "        [0.7545],\n",
       "        [0.6113],\n",
       "        [1.0131],\n",
       "        [1.1096],\n",
       "        [0.5035],\n",
       "        [0.8513],\n",
       "        [0.9683],\n",
       "        [0.6113],\n",
       "        [1.0366],\n",
       "        [0.6724],\n",
       "        [0.7533],\n",
       "        [0.9686],\n",
       "        [1.1037],\n",
       "        [1.0146],\n",
       "        [0.8515],\n",
       "        [0.6670],\n",
       "        [1.1037],\n",
       "        [1.0034],\n",
       "        [0.8485]], grad_fn=<GatherBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    #使用状态计算出动作的logits\n",
    "    #[b, 3] -> [b, 11]\n",
    "    value = model(state)\n",
    "\n",
    "    #根据实际使用的action取出每一个值\n",
    "    #这个值就是模型评估的在该状态下,执行动作的分数\n",
    "    #在执行动作前,显然并不知道会得到的反馈和next_state\n",
    "    #所以这里不能也不需要考虑next_state和reward\n",
    "    #[b, 11] -> [b, 1]\n",
    "    value = value.gather(dim=1, index=action)\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ -2.6482],\n",
       "        [ -1.1666],\n",
       "        [ -7.9569],\n",
       "        [ -2.8249],\n",
       "        [ -5.3108],\n",
       "        [ -8.1130],\n",
       "        [ -0.8815],\n",
       "        [ -3.4968],\n",
       "        [-13.0738],\n",
       "        [ -2.9991],\n",
       "        [ -4.4585],\n",
       "        [ -8.5528],\n",
       "        [ -3.0521],\n",
       "        [-11.0352],\n",
       "        [-12.5834],\n",
       "        [ -5.6039],\n",
       "        [-13.0738],\n",
       "        [ -5.5549],\n",
       "        [ -2.9992],\n",
       "        [ -3.3498],\n",
       "        [ -2.7612],\n",
       "        [ -4.3113],\n",
       "        [ -6.7945],\n",
       "        [ -6.2122],\n",
       "        [ -2.5241],\n",
       "        [ -5.3614],\n",
       "        [ -6.8851],\n",
       "        [-15.0442],\n",
       "        [-12.3906],\n",
       "        [-15.1909],\n",
       "        [ -5.1421],\n",
       "        [ -2.0895],\n",
       "        [ -6.7945],\n",
       "        [-15.0599],\n",
       "        [ -3.4546],\n",
       "        [ -2.4686],\n",
       "        [ -2.8800],\n",
       "        [ -9.0327],\n",
       "        [-13.0885],\n",
       "        [ -0.4519],\n",
       "        [ -1.3342],\n",
       "        [ -3.7201],\n",
       "        [ -4.2673],\n",
       "        [ -3.7201],\n",
       "        [-15.0599],\n",
       "        [ -4.1560],\n",
       "        [ -2.8799],\n",
       "        [ -8.7902],\n",
       "        [-10.5367],\n",
       "        [ -4.0360],\n",
       "        [ -5.3131],\n",
       "        [ -6.8825],\n",
       "        [ -2.8799],\n",
       "        [-10.4599],\n",
       "        [ -3.3498],\n",
       "        [ -4.2650],\n",
       "        [ -6.8851],\n",
       "        [-13.0738],\n",
       "        [-11.8301],\n",
       "        [ -5.3153],\n",
       "        [ -3.7175],\n",
       "        [-13.0736],\n",
       "        [ -8.5528],\n",
       "        [ -5.3106]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    #上面已经把模型认为的状态下执行动作的分数给评估出来了\n",
    "    #下面使用next_state和reward计算真实的分数\n",
    "    #针对一个状态,它到底应该多少分,可以使用以往模型积累的经验评估\n",
    "    #这也是没办法的办法,因为显然没有精确解,这里使用延迟更新的next_model评估\n",
    "\n",
    "    #使用next_state计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        target = next_model(next_state)\n",
    "\n",
    "    #取所有动作中分数最大的\n",
    "    #[b, 11] -> [b, 1]\n",
    "    target = target.max(dim=1)[0]\n",
    "    target = target.reshape(-1, 1)\n",
    "\n",
    "    #下一个状态的分数乘以一个系数,相当于权重\n",
    "    target *= 0.98\n",
    "\n",
    "    #如果next_state已经游戏结束,则next_state的分数是0\n",
    "    #因为如果下一步已经游戏结束,显然不需要再继续玩下去,也就不需要考虑next_state了.\n",
    "    #[b, 1] * [b, 1] -> [b, 1]\n",
    "    target *= (1 - over)\n",
    "\n",
    "    #加上reward就是最终的分数\n",
    "    #[b, 1] + [b, 1] -> [b, 1]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1394.3801358473431"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play=False):\n",
    "    state = env.reset()\n",
    "    reward_sum = 0\n",
    "    over = False\n",
    "    while not over:\n",
    "        _, action_continuous = get_action(state)\n",
    "        state, reward, over, _ = env.step([action_continuous])\n",
    "        reward_sum += reward\n",
    "\n",
    "        if play and random.random() < 0.2:  #跳帧\n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "sum([test() for _ in range(20)]) / 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(200):\n",
    "        #更新N条数据\n",
    "        update_data()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(200):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #计算一批样本的value和target\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #更新参数\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            #把model的参数复制给next_model\n",
    "            if (i + 1) % 50 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            print(epoch, len(datas), sum([test() for _ in range(5)]) / 5)\n",
    "\n",
    "    torch.save(model, 'save/5.DQN_Pendulum')\n",
    "\n",
    "\n",
    "#train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAgiUlEQVR4nO3dfXBU5d038O9uNtm87oYAu0skgfSGgpEXNUhYbWvvIRI0WpG0RaSaWtSBBgpGnJIKOKW14cFW6gtGbx0B7ynGgRrUCEqeAEFKCBiIhChBH5FEYDcBzG4SyNvu7/nD5pRVxGxIcu3G72fmzJBzXbvndyD75TrXOXuOTkQERET9TK+6ACL6fmL4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREsrCZ+3atRg5ciTCw8ORmpqK/fv3qyqFiBRQEj6vv/46cnJy8Pjjj+PgwYOYOHEi0tPTUV9fr6IcIlJAp+KLpampqbjhhhvw3HPPAQC8Xi8SEhKwcOFCLF269Dtf7/V6cerUKcTExECn0/V1uUTUTSKCpqYmxMfHQ6+//NjG0E81adrb21FRUYHc3FxtnV6vR1paGsrKyi75mra2NrS1tWk/nzx5EsnJyX1eKxH1TF1dHYYPH37ZPv0ePmfOnIHH44HVavVZb7VacfTo0Uu+Ji8vD3/84x+/sb6urg4mk6lP6iQi/7ndbiQkJCAmJuY7+/Z7+PREbm4ucnJytJ+7dtBkMjF8iAJQd6ZD+j18hgwZgpCQEDidTp/1TqcTNpvtkq8xGo0wGo39UR4R9ZN+P9sVFhaGlJQUlJSUaOu8Xi9KSkpgt9v7uxwiUkTJYVdOTg6ysrIwadIkTJ48GX//+9/R0tKC+++/X0U5RKSAkvCZNWsWGhoasGLFCjgcDlx77bV49913vzEJTUQDl5LrfK6U2+2G2WyGy+XihDNRAPHns8nvdhGREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREr4HT67d+/GHXfcgfj4eOh0OmzZssWnXUSwYsUKDBs2DBEREUhLS8Mnn3zi0+fcuXOYM2cOTCYTYmNjMXfuXDQ3N1/RjhBRcPE7fFpaWjBx4kSsXbv2ku2rV6/GM888gxdeeAHl5eWIiopCeno6WltbtT5z5sxBdXU1iouLUVRUhN27d+Ohhx7q+V4QUfCRKwBACgsLtZ+9Xq/YbDZ58skntXWNjY1iNBrltddeExGRjz76SADIgQMHtD7btm0TnU4nJ0+e7NZ2XS6XABCXy3Ul5RNRL/Pns9mrcz7Hjx+Hw+FAWlqats5sNiM1NRVlZWUAgLKyMsTGxmLSpElan7S0NOj1epSXl1/yfdva2uB2u30WIgpuvRo+DocDAGC1Wn3WW61Wrc3hcMBisfi0GwwGxMXFaX2+Li8vD2azWVsSEhJ6s2wiUiAoznbl5ubC5XJpS11dneqSiOgK9Wr42Gw2AIDT6fRZ73Q6tTabzYb6+nqf9s7OTpw7d07r83VGoxEmk8lnIaLg1qvhk5SUBJvNhpKSEm2d2+1GeXk57HY7AMBut6OxsREVFRVanx07dsDr9SI1NbU3yyGiAGbw9wXNzc349NNPtZ+PHz+OyspKxMXFITExEYsXL8af//xnjB49GklJSVi+fDni4+MxY8YMAMDVV1+N6dOn48EHH8QLL7yAjo4OLFiwAHfffTfi4+N7bceIKMD5eypt586dAuAbS1ZWloh8dbp9+fLlYrVaxWg0ytSpU6WmpsbnPc6ePSuzZ8+W6OhoMZlMcv/990tTU1O3a+CpdqLA5M9nUyciojD7esTtdsNsNsPlcnH+hyiA+PPZDIqzXUQ08DB8iEgJhg8RKeH32S6i3iAi6HS50OZ0Qjo6YDCZYBw2DDqDATqdTnV51A8YPtTvPK2tOLN9O868954WPiHR0YhOTsawWbMQ+V//xQD6HmD4UL/ytLbi1P/+Lxq2boV4PP9Z39QEV3k5Lpw4gZGLFiE6OZkBNMBxzof6jYjg3K5daNi2zSd4LtbucKDuf/4HnY2N/Vsc9TuGD/UbT3Mz6t95B9LZedl+Fz7/HOfef7+fqiJVGD7Ub9rPnEH7t9w2xYcI2k6d6vuCSCmGD/Wbxr174W1r615nEQThxffkB4YP9Rt/woTBM/AxfCgwMXwGPIYPBSYRBtAAx/ChgCRer+oSqI8xfCgwcdQz4DF8KCAJD7sGPIYPBSYGz4DH8KHAxDmfAY/hQwGJh10DH8OHApMIGD0DG8OHAhNHPQMew4cCEg+7Bj6GDwUmTjgPeAwfCkwc9Qx4DB8KSDzsGvgYPhSYGDwDHsOHAhPnfAY8hg8FJB52DXwMHwpMDJ4Bj+FDAUl4hfOAx/ChwMTDrgGP4UOBiRPOAx7DhwISn14x8DF8KDDxsGvAY/hQv9HpdN3vzMOuAc+v8MnLy8MNN9yAmJgYWCwWzJgxAzU1NT59WltbkZ2djcGDByM6OhqZmZlwOp0+fWpra5GRkYHIyEhYLBY8+uij6PyO53dT8AtPSIDOYOhW39ZTp+Btbe3jikglv8KntLQU2dnZ2LdvH4qLi9HR0YFp06ahpaVF6/Pwww/j7bffxqZNm1BaWopTp05h5syZWrvH40FGRgba29uxd+9ebNiwAevXr8eKFSt6b68oIOnDw4Fujn687e2c9xngdHIF/8INDQ2wWCwoLS3FT37yE7hcLgwdOhQbN27Ez3/+cwDA0aNHcfXVV6OsrAxTpkzBtm3bcPvtt+PUqVOwWq0AgBdeeAG///3v0dDQgLCwsO/crtvthtlshsvlgslk6mn51M9cH3yA/5eXB+no+M6++shIXPP88wiLi+uHyqi3+PPZvKI5H5fLBQCI+/cvSEVFBTo6OpCWlqb1GTt2LBITE1FWVgYAKCsrw/jx47XgAYD09HS43W5UV1dfcjttbW1wu90+CwUhf+Z8aMDrcfh4vV4sXrwYN910E8aNGwcAcDgcCAsLQ2xsrE9fq9UKh8Oh9bk4eLrau9ouJS8vD2azWVsSEhJ6WjapxPChi/Q4fLKzs3HkyBEUFBT0Zj2XlJubC5fLpS11dXV9vk3qXV1fl2jp7MSJ5mZ8eO4c9tXX46PGRng4t/O91L1TD1+zYMECFBUVYffu3Rg+fLi23mazob29HY2NjT6jH6fTCZvNpvXZv3+/z/t1nQ3r6vN1RqMRRqOxJ6WSQiICr9cLp9OJ3bt3Y9s//4n9u3fD1dqKNq8XIoLJQ4dixbXXIiQkRHW51M/8Ch8RwcKFC1FYWIhdu3YhKSnJpz0lJQWhoaEoKSlBZmYmAKCmpga1tbWw2+0AALvdjieeeAL19fWwWCwAgOLiYphMJiQnJ/fGPlEA6OzsxLFjx/Dqq6/irbfeQktLC5ITEvDfw4ZhVHQ0bBERiDIYEGUwIEzPy82+j/wKn+zsbGzcuBFvvvkmYmJitDkas9mMiIgImM1mzJ07Fzk5OYiLi4PJZMLChQtht9sxZcoUAMC0adOQnJyMe++9F6tXr4bD4cCyZcuQnZ3N0c0AICJoaGjAiy++iPXr1yM2NhYPPPAAbr/9dsS63fgiLw/e9nb/LjikAcmv8MnPzwcA/PSnP/VZv27dOvz6178GAKxZswZ6vR6ZmZloa2tDeno6nn/+ea1vSEgIioqKMH/+fNjtdkRFRSErKwsrV668sj0h5bxeLw4cOIClS5fiiy++QHZ2Nn71q19hyJAh0Ov1aKquBnQ6Bg8B6MFh13cJDw/H2rVrsXbt2m/tM2LECGzdutWfTVOA83g8KCoqwpIlSzBmzBhs3rwZ48ePh+6isNHx8Iouwt8GumJerxdvvPEGFi5ciOnTp2PDhg2YMGEC9Hq97yiHIx66SI/OdhF1ERFs374dS5YswT333INly5YhOjr60p11OgYQaTjyoR4TEVRXV+ORRx7B9OnT8dhjj3178OCrwy5GD3Vh+FCPud1urFixAlarFStXrkRMTMzlX8BRD12Eh13UI16vF5s3b8bBgwexceNG7Zqty+FZLroYRz7kNxGB0+nE2rVrMWfOHKSmpnYvWHi2iy7C3wbym4igsLAQ58+fxwMPPAB9d0OFIx+6CMOH/NbS0oKNGzdixowZSExM7PbhlI5nu+giDB/yi4jg4MGDqKurQ2Zmpn9fCOVhF12Evw3kFxHBjh07kJSUhLFjx/r12r6ccPZ6vaitreW9wIMIw4f80tnZiT179mjfy/NLH458Wlpa8Kc//Qlnz57ts21Q72L4kF/q6+vxxRdfYNKkSf6PZPpo5CMiqKysRFFREfbu3csbzwcJhg/55fTp02hubsaYMWP8Dp+vTziLCL5sa8MHZ87gE7cb3h6GhtfrxZYtW9DQ0IB33nkHHo+nR+9D/YsXGZJfvvzySxgMhm/cp7tbLjrsEhHUtrRg+aFDqHG5EGUw4IEf/hCzkpIQ4keoiQjOnDmDrVu3QkSwa9cu1NfXY9iwYbyoMcBx5EN+6ejoQFRUVLcecfR1F4eBAPg/VVXaPZzdHR147uOPceTLL/1+3127duHzzz8HAJw8eRL/+te//H4P6n8MH/KbwWDo/oWFFwsJgf6iu1W6v/b8rnavF21+HjJ1dnZi8+bNSEhIQEREBAYNGoS3336bh15BgOFDfuvo6IC3B89SD42NRdxPfgIA0AH4b5sNhotGQz80mTDi39+KDzEau3XzsRMnTmDUqFF49tlnYbFY8NRTT2HIkCHf+hgmChyc8yG/hIaG4vz582hvb/f7tTq9HkPS0nB2xw54mpuRNWoUYkJD8X9Pn8awiAg8+MMfwhIeDgAYdPPNMHTjabQWiwXLly/Hhx9+CAC49tprcdttt3G+JwgwfMgvgwYNQmdnJxobG30em9Rd4YmJGDZ7Nk5u2ABDezt+MXIkfj5ypHafH51Oh+jkZFhnzOjWyKfrNh6nT59GSEgI4uLivvvWHhQQeNhFfhk2bBiio6Nx9OjRHl1Po9PrMXT6dFyVlYXQuDjo9Hro/32fZ31YGMyTJ2PEokXdfka7TqeDiOCTTz6BzWZDVFSUz32jKXBx5EN+sVgsSEhIQEVFBWbOnNmjD7k+NBSWjAyYr78e7kOH0OZwQB8Rgeirr0b01VcjJCLCr/fzeDw4cOAAxo8fz8cvBRGGD/nFYDDgRz/6EXbt2oWWlpYeH+Lo9HqEX3UVwq+66opramhowJEjR7Bs2TI++TSI8LCL/KLT6TB16lR8/vnnOHbsmOpyICLYs2cPOjo6cNNNN/FwK4gwfMgvOp0O1113HRITE7F582bl19N0dHTg9ddfx49//GNc1QujKOo/DB/yW2RkJGbPno3CwkLU1tYq+yKniGD//v0oLy/HnDlzEBoaqqQO6hmGD/lNr9fjrrvuQmRkJF566aUeXXDYGy5cuICnn34a119/PW688UYecgUZhg/1iMViQXZ2NjZu3Ih9+/b1++jH6/WisLAQe/fuRU5ODiIjI/t1+3TlGD7UI3q9Hr/4xS+QkpKC5cuXw+l09tu2ux5W+MQTT+Cee+7hRHOQYvhQj8XExGDlypWor6/H8uXL4Xa7+3ybIgKHw4ElS5YgPj4eS5YsgcHAK0aCEcOHekyn0yE5ORl/+9vf8N5772HlypVoamrqs+2JCBoaGpCTkwOn04k1a9bAYrFw1BOkGD50RXQ6HW655Rb89a9/xeuvv46lS5fizJkzvT4HJCI4ceIE5s2bh8OHDyM/Px/jxo1j8AQxhg9dMb1ej8zMTDz33HPYvn07srKycPjwYXi93isOIRFBZ2cnduzYgV/+8peora3FK6+8gilTpjB4ghzDh3pFSEgI7rjjDvzjH/9Aa2srMjMz8dRTT6G+vr5Hp+JFBB6PB5999hlyc3ORlZWFH/zgBygoKMDkyZMZPAOAToLwVv9utxtmsxkulwumbtzzhfpP17zMSy+9hHXr1sFsNmP27Nm4/fbbkZiYiIiIiMsGh9frxfnz53Hs2DFs3rwZ//znP2E0GvG73/0Os2bNQnR0NIMngPnz2WT4UJ/o7OzEp59+ildffRVbtmxBc3MzJkyYgMmTJ2PChAkYMWIEzGYzQkJC0NHRgS+//BKfffYZPvzwQ+zbtw/Hjh2DxWLB3XffjbvvvhvDhw/v2a1bqV8xfCggiAi8Xi+cTifef/99bN++HZWVlTh79iwuXLjgMx+k1+sRGRkJq9WKyZMn45ZbbsGUKVMQFxfH+/MEkT4Ln/z8fOTn52tPCrjmmmuwYsUK3HrrrQCA1tZWPPLIIygoKEBbWxvS09Px/PPPw2q1au9RW1uL+fPnY+fOnYiOjkZWVhby8vL8ulaD4RN8RAQigubmZjgcDpw7dw4tLS3weDwwGAyIjo7GkCFDYLPZEB4ezsAJUv58Nv26Omv48OFYtWoVRo8eDRHBhg0bcOedd+LQoUO45ppr8PDDD+Odd97Bpk2bYDabsWDBAsycOVN7lInH40FGRgZsNhv27t2L06dP47777kNoaCj+8pe/9HyPKeB1hYnJZOJ/GPQVuUKDBg2Sl19+WRobGyU0NFQ2bdqktX388ccCQMrKykREZOvWraLX68XhcGh98vPzxWQySVtb27duo7W1VVwul7bU1dUJAHG5XFdaPhH1IpfL1e3PZo9n8DweDwoKCtDS0gK73Y6Kigp0dHQgLS1N6zN27FgkJiairKwMAFBWVobx48f7HIalp6fD7Xajurr6W7eVl5cHs9msLQkJCT0tm4gChN/hU1VVhejoaBiNRsybNw+FhYVITk6Gw+FAWFjYNx6ja7VatWcoORwOn+Dpau9q+za5ublwuVzaUldX52/ZRBRg/P5G3pgxY1BZWQmXy4XNmzcjKysLpaWlfVGbxmg08sbgRAOM3+ETFhaGUaNGAQBSUlJw4MABPP3005g1axba29vR2NjoM/pxOp2w2WwAAJvNhv379/u8X9etGLr6ENH3wxVfteX1etHW1oaUlBSEhoaipKREa6upqUFtbS3sdjsAwG63o6qqCvX19Vqf4uJimEwmJCcnX2kpRBRE/Br55Obm4tZbb0ViYiKampqwceNG7Nq1C++99x7MZjPmzp2LnJwcxMXFwWQyYeHChbDb7ZgyZQoAYNq0aUhOTsa9996L1atXw+FwYNmyZcjOzuZhFdH3jF/hU19fj/vuuw+nT5+G2WzGhAkT8N577+GWW24BAKxZs0b7hvPFFxl2CQkJQVFREebPnw+73Y6oqChkZWVh5cqVvbtXRBTw+PUKIuo1/nw2+U09IlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhIiSsKn1WrVkGn02Hx4sXautbWVmRnZ2Pw4MGIjo5GZmYmnE6nz+tqa2uRkZGByMhIWCwWPProo+js7LySUogoyPQ4fA4cOIAXX3wREyZM8Fn/8MMP4+2338amTZtQWlqKU6dOYebMmVq7x+NBRkYG2tvbsXfvXmzYsAHr16/HihUrer4XRBR8pAeamppk9OjRUlxcLDfffLMsWrRIREQaGxslNDRUNm3apPX9+OOPBYCUlZWJiMjWrVtFr9eLw+HQ+uTn54vJZJK2trZLbq+1tVVcLpe21NXVCQBxuVw9KZ+I+ojL5er2Z7NHI5/s7GxkZGQgLS3NZ31FRQU6Ojp81o8dOxaJiYkoKysDAJSVlWH8+PGwWq1an/T0dLjdblRXV19ye3l5eTCbzdqSkJDQk7KJKID4HT4FBQU4ePAg8vLyvtHmcDgQFhaG2NhYn/VWqxUOh0Prc3HwdLV3tV1Kbm4uXC6XttTV1flbNhEFGIM/nevq6rBo0SIUFxcjPDy8r2r6BqPRCKPR2G/bI6K+59fIp6KiAvX19bj++uthMBhgMBhQWlqKZ555BgaDAVarFe3t7WhsbPR5ndPphM1mAwDYbLZvnP3q+rmrDxENfH6Fz9SpU1FVVYXKykptmTRpEubMmaP9OTQ0FCUlJdprampqUFtbC7vdDgCw2+2oqqpCfX291qe4uBgmkwnJycm9tFtEFOj8OuyKiYnBuHHjfNZFRUVh8ODB2vq5c+ciJycHcXFxMJlMWLhwIex2O6ZMmQIAmDZtGpKTk3Hvvfdi9erVcDgcWLZsGbKzs3loRfQ94lf4dMeaNWug1+uRmZmJtrY2pKen4/nnn9faQ0JCUFRUhPnz58NutyMqKgpZWVlYuXJlb5dCRAFMJyKiugh/ud1umM1muFwumEwm1eUQ0b/589nkd7uISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUYPgQkRIMHyJSguFDREowfIhICYYPESnB8CEiJRg+RKQEw4eIlGD4EJESDB8iUoLhQ0RKMHyISAmGDxEpwfAhIiUYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iUMKguoCdEBADgdrsVV0JEF+v6THZ9Ri8nKMPn7NmzAICEhATFlRDRpTQ1NcFsNl+2T1CGT1xcHACgtrb2O3cw0LjdbiQkJKCurg4mk0l1Od3GuvtXsNYtImhqakJ8fPx39g3K8NHrv5qqMpvNQfUPczGTyRSUtbPu/hWMdXd3QMAJZyJSguFDREoEZfgYjUY8/vjjMBqNqkvxW7DWzrr7V7DW7Q+ddOecGBFRLwvKkQ8RBT+GDxEpwfAhIiUYPkSkBMOHiJQIyvBZu3YtRo4cifDwcKSmpmL//v1K69m9ezfuuOMOxMfHQ6fTYcuWLT7tIoIVK1Zg2LBhiIiIQFpaGj755BOfPufOncOcOXNgMpkQGxuLuXPnorm5uU/rzsvLww033ICYmBhYLBbMmDEDNTU1Pn1aW1uRnZ2NwYMHIzo6GpmZmXA6nT59amtrkZGRgcjISFgsFjz66KPo7Ozss7rz8/MxYcIE7epfu92Obdu2BXTNl7Jq1SrodDosXrw46GrvFRJkCgoKJCwsTF555RWprq6WBx98UGJjY8XpdCqraevWrfLYY4/JG2+8IQCksLDQp33VqlViNptly5Yt8uGHH8rPfvYzSUpKkgsXLmh9pk+fLhMnTpR9+/bJ+++/L6NGjZLZs2f3ad3p6emybt06OXLkiFRWVsptt90miYmJ0tzcrPWZN2+eJCQkSElJiXzwwQcyZcoUufHGG7X2zs5OGTdunKSlpcmhQ4dk69atMmTIEMnNze2zut966y1555135NixY1JTUyN/+MMfJDQ0VI4cORKwNX/d/v37ZeTIkTJhwgRZtGiRtj4Yau8tQRc+kydPluzsbO1nj8cj8fHxkpeXp7Cq//h6+Hi9XrHZbPLkk09q6xobG8VoNMprr70mIiIfffSRAJADBw5ofbZt2yY6nU5OnjzZb7XX19cLACktLdXqDA0NlU2bNml9Pv74YwEgZWVlIvJV8Or1enE4HFqf/Px8MZlM0tbW1m+1Dxo0SF5++eWgqLmpqUlGjx4txcXFcvPNN2vhEwy196agOuxqb29HRUUF0tLStHV6vR5paWkoKytTWNm3O378OBwOh0/NZrMZqampWs1lZWWIjY3FpEmTtD5paWnQ6/UoLy/vt1pdLheA/9w1oKKiAh0dHT61jx07FomJiT61jx8/HlarVeuTnp4Ot9uN6urqPq/Z4/GgoKAALS0tsNvtQVFzdnY2MjIyfGoEguPvuzcF1bfaz5w5A4/H4/MXDwBWqxVHjx5VVNXlORwOALhkzV1tDocDFovFp91gMCAuLk7r09e8Xi8WL16Mm266CePGjdPqCgsLQ2xs7GVrv9S+dbX1laqqKtjtdrS2tiI6OhqFhYVITk5GZWVlwNYMAAUFBTh48CAOHDjwjbZA/vvuC0EVPtR3srOzceTIEezZs0d1Kd0yZswYVFZWwuVyYfPmzcjKykJpaanqsi6rrq4OixYtQnFxMcLDw1WXo1xQHXYNGTIEISEh35j9dzqdsNlsiqq6vK66LlezzWZDfX29T3tnZyfOnTvXL/u1YMECFBUVYefOnRg+fLi23mazob29HY2NjZet/VL71tXWV8LCwjBq1CikpKQgLy8PEydOxNNPPx3QNVdUVKC+vh7XX389DAYDDAYDSktL8cwzz8BgMMBqtQZs7X0hqMInLCwMKSkpKCkp0dZ5vV6UlJTAbrcrrOzbJSUlwWaz+dTsdrtRXl6u1Wy329HY2IiKigqtz44dO+D1epGamtpntYkIFixYgMLCQuzYsQNJSUk+7SkpKQgNDfWpvaamBrW1tT61V1VV+YRncXExTCYTkpOT+6z2r/N6vWhrawvomqdOnYqqqipUVlZqy6RJkzBnzhztz4Fae59QPePtr4KCAjEajbJ+/Xr56KOP5KGHHpLY2Fif2f/+1tTUJIcOHZJDhw4JAHnqqafk0KFDcuLECRH56lR7bGysvPnmm3L48GG58847L3mq/brrrpPy8nLZs2ePjB49us9Ptc+fP1/MZrPs2rVLTp8+rS3nz5/X+sybN08SExNlx44d8sEHH4jdbhe73a61d536nTZtmlRWVsq7774rQ4cO7dNTv0uXLpXS0lI5fvy4HD58WJYuXSo6nU62b98esDV/m4vPdgVb7Vcq6MJHROTZZ5+VxMRECQsLk8mTJ8u+ffuU1rNz504B8I0lKytLRL463b58+XKxWq1iNBpl6tSpUlNT4/MeZ8+eldmzZ0t0dLSYTCa5//77pampqU/rvlTNAGTdunVanwsXLshvf/tbGTRokERGRspdd90lp0+f9nmfzz//XG699VaJiIiQIUOGyCOPPCIdHR19VvdvfvMbGTFihISFhcnQoUNl6tSpWvAEas3f5uvhE0y1Xynez4eIlAiqOR8iGjgYPkSkBMOHiJRg+BCREgwfIlKC4UNESjB8iEgJhg8RKcHwISIlGD5EpATDh4iU+P+0ibuOjXGHtgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-356.57743230826736"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = torch.load('save/5.DQN_Pendulum')\n",
    "\n",
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
